"""
Batch AB (EIB/Attentional Blink) feature extraction — UPDATED
Adds NEW TF windows:
  - EARLY: 100–200 ms
  - LATE:  400–500 ms

Outputs per trial:
  Factors:
    subject, event_code, threat_type (T/G/N), target_present (0/1), correct (1/0)
  ERP features (Volts):
    erp_epn_post (180–300 ms; P3/P4/POz)
    erp_pretgt_post (250–375 ms; P3/P4/POz)
    erp_tgtwin_post (375–500 ms; P3/P4/POz)
    erp_pretgt_fm (250–375 ms; Fz/Cz)
  TF features (dB, baseline -200..0 ms):
    post_alpha_db_{early, pretgt, late, tgt}
    post_theta_db_{early, pretgt, late, tgt}
    fm_theta_db_{early, pretgt, late, tgt}
  TBR features (log of baseline-relative theta power minus log of baseline-relative beta power; computed from the same fm linear-power TFR):
    fm_tbr_log_{early, pretgt, late, tgt}

Channels required:
  Fz, Cz (frontal-midline ROI) and P3, P4, POz (posterior ROI).
  Other channels (e.g. C3, C4, F3, F4) may be present but are not used.

How to run:
  Activate your conda env, then:
    python batch_ab_features.py
  Then select:
    1) Correct folder
    2) Incorrect folder
    3) Output folder

Notes:
- Optimizations relative to earlier versions:
  - Computes Morlet TFR ONCE per ROI per subject (fm + posterior), then slices all windows/bands.
  - Computes TBR from the *same* fm linear-power TFR (no extra TFR calls).
"""

import os
import re
import glob
import numpy as np
import pandas as pd
import mne
from tkinter import filedialog, Tk
from datetime import datetime

# -----------------------------
# Configuration
# -----------------------------
# Event codes retained for analysis (compared after label normalization).
ALLOWED_CODES = {"CT", "T", "CG", "G", "N", "CN"}

# Regions of interest (channel names)
CH_FM = ["Fz", "Cz"]            # frontal-midline ROI
CH_POST = ["P3", "P4", "POz"]   # posterior ROI

# ERP mean-amplitude windows (seconds, distractor-locked)
WIN_EPN_POST = (0.18, 0.30)
WIN_PRETGT = (0.25, 0.375)
WIN_TGTWIN_POST = (0.375, 0.50)

# Time-frequency windows (seconds): the newer early/late windows plus the
# pre-target / target windows kept for mechanism analyses.
WIN_EARLY = (0.100, 0.200)
WIN_LATE = (0.400, 0.500)
WIN_PRETGT_TF = (0.250, 0.375)
WIN_TGT_TF = (0.375, 0.500)

# Baseline used for TF dB and relative-power normalization (seconds)
BASELINE = (-0.2, 0.0)

# Frequency bands in Hz (treated as closed intervals)
THETA = (4, 7)
ALPHA = (8, 12)
BETA = (13, 30)

# Morlet grid: integer frequencies 4..30 Hz. With n_cycles = f / 2, the ratio
# n_cycles / f is 0.5 s at every frequency.
FREQS = np.arange(4, 31)
N_CYCLES = 0.5 * FREQS

# Performance knobs
DECIM = 1   # TFR decimation factor; set to 2 for a speed-up
N_JOBS = 1  # parallel jobs for the TFR; increase for parallelism

# Output file names
PER_SUBJECT_CSV_SUFFIX = "_AB_trial_features.csv"
MERGED_CSV_NAME = "AB_all_subjects_trial_features.csv"


# -----------------------------
# GUI / IO
# -----------------------------
def pick_folders_and_outdir():
    """
    Ask the user, via Tk directory dialogs, for the three folders the batch needs.

    Returns:
      (correct_dir, incorrect_dir, out_dir). If a dialog is cancelled, the
      remaining dialogs are skipped; the cancelled and skipped entries are "".
      Callers treat an empty value as "cancelled".
    """
    root = Tk()
    root.withdraw()  # hide the empty main window; only the dialogs are shown
    try:
        titles = (
            "Select folder containing Correct epochs .fif files",
            "Select folder containing Incorrect epochs .fif files",
            "Select output folder for CSVs",
        )
        chosen = []
        for title in titles:
            path = filedialog.askdirectory(title=title)
            if not path:
                # Cancelled: stop prompting instead of showing the remaining dialogs.
                break
            chosen.append(path)
        chosen += [""] * (len(titles) - len(chosen))
        return tuple(chosen)
    finally:
        # Tear down the Tk interpreter; the original left it alive for the whole run.
        root.destroy()


# -----------------------------
# Label + parsing helpers
# -----------------------------
def normalize_event_label(lbl: str) -> str:
    """Canonicalize an event label: trim surrounding whitespace, upper-case, drop every space."""
    cleaned = str(lbl).strip().upper()
    return "".join(cleaned.split(" "))


def inv_event_map(epochs) -> np.ndarray:
    """
    Return the normalized event label of every epoch, in epoch order.

    Labels are recovered by inverting ``epochs.event_id`` (label -> code) and
    looking up each epoch's code in ``epochs.events[:, 2]``. Each label is then
    passed through ``normalize_event_label``.

    Returns:
      1-D object array of label strings, one per epoch.

    Raises:
      RuntimeError: if ``epochs.event_id`` is empty/None, or if any epoch
        carries an event code that has no label in ``epochs.event_id``.
        The second case previously escaped as a bare ``KeyError``.
    """
    if not epochs.event_id:
        raise RuntimeError("epochs.event_id is empty; cannot map events to labels.")
    # If two labels share a code, the later one wins (same as the original dict inversion).
    code_to_label = {int(code): label for label, code in epochs.event_id.items()}
    codes = [int(c) for c in epochs.events[:, 2]]
    unmapped = sorted(set(codes).difference(code_to_label))
    if unmapped:
        raise RuntimeError(
            f"Event codes {unmapped} have no label in epochs.event_id; "
            "cannot map events to labels."
        )
    return np.array([normalize_event_label(code_to_label[c]) for c in codes], dtype=object)


def code_to_factors(code: str):
    """
    Split an event code into (threat_type, target_present).

    The code is normalized first. A single leading "C" (CT, CG, CN) yields
    target_present = 0; codes without it (T, G, N) yield target_present = 1.
    threat_type is the code with only that leading "C" removed
    (CT -> T, CG -> G, CN -> N). The original ``replace("C", "")`` also
    deleted any later "C" characters.

    Returns:
      (threat_type: str, target_present: int)
    """
    code = normalize_event_label(code)
    has_c_prefix = code.startswith("C")
    threat = code[1:] if has_c_prefix else code
    target_present = 0 if has_c_prefix else 1
    return threat, target_present


def subject_key_from_filename(path: str) -> str:
    """
    Derive the subject key shared by one subject's Correct and Incorrect files.

    The directory and extension are dropped. Runs of '_', '-' and whitespace
    collapse to single spaces. The whole words EEG, EPOCH/EPOCHS, FILTERED,
    CORRECT and INCORRECT (any case) are removed. Examples:
      'Ashley EEG Filtered_Correct.fif' -> 'Ashley'
      'Ashley Filtered Incorrect.fif'   -> 'Ashley'
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    stem = re.sub(r"[\s_\-]+", " ", stem).strip()
    stem = re.sub(
        r"\b(?:EEG|EPOCHS?|FILTERED|CORRECT|INCORRECT)\b",
        "",
        stem,
        flags=re.IGNORECASE,
    )
    return re.sub(r"\s+", " ", stem).strip()


def build_file_map(folder: str):
    """
    Map subject key -> .fif path for every .fif file directly inside *folder*.

    Files whose key comes out empty are skipped. If several files share a key,
    the one with the most recent modification time is kept.
    """
    newest = {}
    for path in glob.glob(os.path.join(folder, "*.fif")):
        key = subject_key_from_filename(path)
        if not key:
            continue
        current = newest.get(key)
        if current is None or os.path.getmtime(path) > os.path.getmtime(current):
            newest[key] = path
    return newest


# -----------------------------
# Feature extraction helpers
# -----------------------------
def mean_amp(epochs, picks, tmin, tmax):
    """
    Per-epoch mean amplitude over the picked channels within the closed window [tmin, tmax].

    Returns:
      1-D array with one value per epoch, in the units returned by ``get_data``.

    Raises:
      RuntimeError: if no sample time falls inside the window.
    """
    signal = epochs.get_data(picks=picks)  # (n_epochs, n_channels, n_times)
    sample_times = epochs.times
    in_window = np.flatnonzero((sample_times >= tmin) & (sample_times <= tmax))
    if in_window.size == 0:
        raise RuntimeError(f"No samples in ERP window {tmin}-{tmax}. Check epoch times.")
    return np.take(signal, in_window, axis=2).mean(axis=(1, 2))


def band_mask(freqs, band):
    """
    Boolean mask selecting the entries of *freqs* inside the closed interval *band*.

    Args:
      freqs: 1-D sequence/array of frequencies in Hz (lists are accepted).
      band: (lo, hi) inclusive bounds in Hz.

    Returns:
      Boolean ndarray with the same shape as ``freqs``.

    Raises:
      RuntimeError: if no frequency lies inside the band. Averaging over an
        empty selection would otherwise produce NaN with only a
        RuntimeWarning, silently yielding all-NaN feature columns.
    """
    lo, hi = band
    freqs = np.asarray(freqs)
    mask = (freqs >= lo) & (freqs <= hi)
    if not mask.any():
        raise RuntimeError(f"No frequencies in band {band}. Check the TFR frequency grid.")
    return mask


def compute_tfr_linear(epochs, picks, freqs, n_cycles):
    """
    Single-trial Morlet power for an ROI, averaged over the ROI's channels.

    A copy of *epochs* restricted to *picks* is transformed with
    ``tfr_morlet(average=False)``, using the module-level DECIM and N_JOBS.
    The per-channel power is then averaged across channels.

    Returns:
      power: (n_epochs, n_freqs, n_times) linear power
      times: (n_times,) TFR time axis (after decimation)
      freqs: (n_freqs,) TFR frequency axis
    """
    roi_epochs = epochs.copy().pick(picks)
    morlet_kwargs = dict(
        freqs=freqs,
        n_cycles=n_cycles,
        use_fft=True,
        return_itc=False,
        average=False,
        decim=DECIM,
        n_jobs=N_JOBS,
        verbose="ERROR",
    )
    epochs_tfr = mne.time_frequency.tfr_morlet(roi_epochs, **morlet_kwargs)
    # data is (n_epochs, n_channels, n_freqs, n_times); collapse the channel axis.
    roi_power = np.mean(epochs_tfr.data, axis=1)
    return roi_power, epochs_tfr.times, epochs_tfr.freqs


def baseline_indices(times, baseline):
    """
    Indices of *times* inside the closed baseline interval (start, stop).

    Raises:
      RuntimeError: if the interval contains no samples.
    """
    start, stop = baseline
    inside = np.flatnonzero((times >= start) & (times <= stop))
    if not inside.size:
        raise RuntimeError(f"No samples in baseline window {baseline}. Check epoch times.")
    return inside


def window_indices(times, win):
    """
    Indices of *times* inside the closed TF window (start, stop).

    Raises:
      RuntimeError: if the window contains no samples.
    """
    start, stop = win
    inside = np.flatnonzero((times >= start) & (times <= stop))
    if not inside.size:
        raise RuntimeError(f"No samples in TF window {win}. Check epoch times.")
    return inside


def power_db_from_linear(power, times, baseline):
    """
    Convert linear power to baseline-corrected decibels: 10*log10(power / baseline_power).

    The baseline is the mean over the baseline samples (last axis), taken
    separately for every epoch and frequency. Power is expected as
    (n_epochs, n_freqs, n_times), as returned by compute_tfr_linear. Both
    numerator and denominator are floored at the smallest positive float so
    the log stays finite.
    """
    floor = np.finfo(float).tiny
    ref = power[:, :, baseline_indices(times, baseline)].mean(axis=2, keepdims=True)
    ratio = np.maximum(power, floor) / np.maximum(ref, floor)
    return 10.0 * np.log10(ratio)


def window_band_mean(power_db, times, freqs, band, win):
    """
    Per-epoch mean of *power_db* over the frequencies in *band* and the samples in *win*.

    Returns:
      1-D array with one value per epoch.

    Raises:
      RuntimeError: propagated from window_indices when *win* has no samples.
    """
    freq_sel = band_mask(freqs, band)
    time_sel = window_indices(times, win)
    in_band = power_db[:, freq_sel, :]
    return in_band[:, :, time_sel].mean(axis=(1, 2))


def window_log_tbr_from_linear(power_lin, times, freqs, baseline, win):
    """
    Per-epoch log theta/beta ratio of baseline-relative power inside *win*.

    Linear power is divided by its per-epoch, per-frequency baseline mean.
    That relative power is averaged over the THETA (resp. BETA) frequencies
    and the window samples. The result is log(theta_rel) - log(beta_rel)
    (natural log). All intermediate values are floored at the smallest
    positive float so the logs stay finite.
    """
    floor = np.finfo(float).tiny
    base_idx = baseline_indices(times, baseline)
    base_power = np.maximum(power_lin[:, :, base_idx].mean(axis=2, keepdims=True), floor)
    relative = np.maximum(power_lin, floor) / base_power
    win_idx = window_indices(times, win)

    def band_average(band):
        # Mean relative power over the band's frequencies and the window's samples.
        selected = band_mask(freqs, band)
        return np.maximum(relative[:, selected, :][:, :, win_idx].mean(axis=(1, 2)), floor)

    theta_rel = band_average(THETA)
    beta_rel = band_average(BETA)
    return np.log(theta_rel) - np.log(beta_rel)


# -----------------------------
# Per-subject extraction
# -----------------------------
def extract_features_for_subject(subj_key: str, correct_path: str, incorrect_path: str) -> pd.DataFrame:
    """
    Build the per-trial feature table for one subject.

    Loads the Correct and Incorrect epochs files and concatenates them
    (Correct first). Only epochs whose normalized label is in ALLOWED_CODES
    are kept. The following are then added:
      - ERP mean amplitudes;
      - baseline-corrected TF band power (dB) per TF window;
      - frontal-midline log TBR per TF window.

    Returns:
      DataFrame with one row per retained epoch, in concatenation order.

    Raises:
      RuntimeError: if either file lacks a required channel, or if no epoch
        has an allowed event code.
    """
    ep_c = mne.read_epochs(correct_path, preload=True, verbose="ERROR")
    ep_i = mne.read_epochs(incorrect_path, preload=True, verbose="ERROR")

    # Validate BOTH files. The original checked only the Correct file, so a
    # channel missing from the Incorrect file surfaced later (presumably as an
    # info mismatch inside concatenate_epochs) with a less helpful message.
    required = CH_FM + CH_POST
    for which, ep in (("correct", ep_c), ("incorrect", ep_i)):
        missing = [ch for ch in required if ch not in ep.ch_names]
        if missing:
            raise RuntimeError(f"{subj_key}: missing required channels in {which} file: {missing}")

    # Correct epochs come first, so the flags below line up with the concatenation.
    epochs = mne.concatenate_epochs([ep_c, ep_i])
    correct_flags = np.array([1] * len(ep_c) + [0] * len(ep_i), dtype=int)
    del ep_c, ep_i  # the concatenated copy holds the data; free the originals

    # Condition codes per epoch, restricted to the expected set
    epoch_codes = inv_event_map(epochs)
    keep = np.array([c in ALLOWED_CODES for c in epoch_codes], dtype=bool)
    if not keep.any():
        raise RuntimeError(
            f"{subj_key}: no epochs with event codes in {sorted(ALLOWED_CODES)}; "
            f"found {sorted(set(epoch_codes))}."
        )
    if not keep.all():
        epochs = epochs[keep]
        epoch_codes = epoch_codes[keep]
        correct_flags = correct_flags[keep]

    # Base trial table. Factors are built directly rather than via a row-wise
    # apply(pd.Series); target_present becomes a proper int column.
    factors = [code_to_factors(c) for c in epoch_codes]
    df = pd.DataFrame({
        "subject": subj_key,
        "event_code": epoch_codes,
        "correct": correct_flags,
        "threat_type": [threat for threat, _ in factors],
        "target_present": np.array([present for _, present in factors], dtype=int),
    })

    # Picks
    picks_fm = mne.pick_channels(epochs.ch_names, include=CH_FM)
    picks_post = mne.pick_channels(epochs.ch_names, include=CH_POST)

    # ERP features
    df["erp_epn_post"] = mean_amp(epochs, picks_post, *WIN_EPN_POST)
    df["erp_pretgt_post"] = mean_amp(epochs, picks_post, *WIN_PRETGT)
    df["erp_tgtwin_post"] = mean_amp(epochs, picks_post, *WIN_TGTWIN_POST)
    df["erp_pretgt_fm"] = mean_amp(epochs, picks_fm, *WIN_PRETGT)

    # One TFR per ROI (linear power); all windows/bands are sliced from these.
    fm_lin, tf_times, tf_freqs = compute_tfr_linear(epochs, picks_fm, FREQS, N_CYCLES)
    fm_db = power_db_from_linear(fm_lin, tf_times, BASELINE)

    post_lin, tf_times2, tf_freqs2 = compute_tfr_linear(epochs, picks_post, FREQS, N_CYCLES)
    post_db = power_db_from_linear(post_lin, tf_times2, BASELINE)

    if not np.allclose(tf_times, tf_times2) or not np.allclose(tf_freqs, tf_freqs2):
        raise RuntimeError(f"{subj_key}: TF grids differ between ROIs; unexpected.")

    # TF window registry (dict order fixes the CSV column order)
    tf_windows = {
        "early": WIN_EARLY,
        "pretgt": WIN_PRETGT_TF,
        "late": WIN_LATE,
        "tgt": WIN_TGT_TF,
    }

    # Posterior alpha/theta dB per window
    for wname, win in tf_windows.items():
        df[f"post_alpha_db_{wname}"] = window_band_mean(post_db, tf_times, tf_freqs, ALPHA, win)
        df[f"post_theta_db_{wname}"] = window_band_mean(post_db, tf_times, tf_freqs, THETA, win)

    # Frontal-midline theta dB per window
    for wname, win in tf_windows.items():
        df[f"fm_theta_db_{wname}"] = window_band_mean(fm_db, tf_times, tf_freqs, THETA, win)

    # Log TBR per window from the same fm linear power (no extra TFR calls)
    for wname, win in tf_windows.items():
        df[f"fm_tbr_log_{wname}"] = window_log_tbr_from_linear(fm_lin, tf_times, tf_freqs, BASELINE, win)

    return df


# -----------------------------
# Main batch runner
# -----------------------------
def _write_batch_log(out_dir, correct_dir, incorrect_dir, matched, missing,
                     correct_map, incorrect_map, n_processed, errors):
    """Write a timestamped batch log into *out_dir* and return its path."""
    log_path = os.path.join(out_dir, f"batch_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")
    lines = [
        "Batch AB feature extraction log",
        f"Correct folder: {correct_dir}",
        f"Incorrect folder: {incorrect_dir}",
        f"Output folder: {out_dir}",
        "",
        f"Matched subject pairs: {len(matched)}",
        f"Processed successfully: {n_processed}",
    ]
    if missing:
        lines += ["", "Missing pairs:"]
        lines += [f"  {k} | correct={k in correct_map} | incorrect={k in incorrect_map}" for k in missing]
    if errors:
        lines += ["", "Errors:"]
        lines += [f"  {m}" for m in errors]
    lines += ["", "Done."]
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    return log_path


def main():
    """
    Interactive batch entry point.

    Asks for the Correct / Incorrect / output folders and pairs files by
    subject key. Features are extracted for every complete pair, and one CSV
    per subject plus a merged CSV are written. A timestamped log is written to
    the output folder whenever folders were selected. Previously the log was
    skipped when no subject succeeded, so the errors were never recorded.
    """
    correct_dir, incorrect_dir, out_dir = pick_folders_and_outdir()
    if not correct_dir or not incorrect_dir or not out_dir:
        print("Canceled.")
        return

    print("Correct folder:", correct_dir)
    print("Incorrect folder:", incorrect_dir)
    print("Output folder:", out_dir)

    correct_map = build_file_map(correct_dir)
    incorrect_map = build_file_map(incorrect_dir)

    all_keys = sorted(set(correct_map) | set(incorrect_map))
    print(f"Detected subject keys (union): {len(all_keys)}")

    matched = [k for k in all_keys if k in correct_map and k in incorrect_map]
    missing = [k for k in all_keys if k not in correct_map or k not in incorrect_map]

    if missing:
        print("\nWARNING: Some subjects are missing a pair (correct or incorrect).")
        for k in missing:
            print("  Missing pair for:", k,
                  "| correct:", k in correct_map,
                  "| incorrect:", k in incorrect_map)

    if not matched:
        print("No matched subject pairs found. Check filenames / folders.")
        log_path = _write_batch_log(out_dir, correct_dir, incorrect_dir, matched, missing,
                                    correct_map, incorrect_map, 0, [])
        print("Saved log:", log_path)
        return

    print(f"\nMatched subject pairs: {len(matched)}")

    all_dfs = []
    errors = []

    for idx, subj_key in enumerate(matched, start=1):
        c_path = correct_map[subj_key]
        i_path = incorrect_map[subj_key]
        print(f"\n[{idx}/{len(matched)}] Processing: {subj_key}")
        print("  Correct:", os.path.basename(c_path))
        print("  Incorrect:", os.path.basename(i_path))

        try:
            df_subj = extract_features_for_subject(subj_key, c_path, i_path)
            per_path = os.path.join(out_dir, f"{subj_key}{PER_SUBJECT_CSV_SUFFIX}")
            df_subj.to_csv(per_path, index=False)
            print("  Saved per-subject CSV:", per_path)
            all_dfs.append(df_subj)
        except Exception as e:  # batch boundary: record the failure and continue
            msg = f"{subj_key}: {repr(e)}"
            print("  ERROR:", msg)
            errors.append(msg)

    if all_dfs:
        df_all = pd.concat(all_dfs, ignore_index=True)

        merged_path = os.path.join(out_dir, MERGED_CSV_NAME)
        df_all.to_csv(merged_path, index=False)
        print("\nSaved merged CSV:", merged_path)

        # quick summaries
        print("\nSummary counts (event_code x correct):")
        print(df_all.groupby(["event_code", "correct"]).size())

        print("\nSummary counts (threat_type x target_present x correct):")
        print(df_all.groupby(["threat_type", "target_present", "correct"]).size())
    else:
        print("\nNo subjects processed successfully.")

    if errors:
        print(f"\nErrors ({len(errors)}):")
        for m in errors:
            print(" ", m)

    # Always log, so failures are recorded even when no subject succeeded.
    log_path = _write_batch_log(out_dir, correct_dir, incorrect_dir, matched, missing,
                                correct_map, incorrect_map, len(all_dfs), errors)
    print("Saved log:", log_path)

    print("\nDone.")


# Run the interactive batch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()